{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "package_dir =os.path.abspath('').replace(\"/flashface/all_finetune\",\"\")\n",
    "sys.path.insert(0, package_dir)\n",
    "import os\n",
    "\n",
    "import copy\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.cuda.amp as amp\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as T\n",
    "\n",
    "from config import cfg\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "import copy\n",
    "\n",
    "import os\n",
    "\n",
    "import random\n",
    "import sys\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.cuda.amp as amp\n",
    "\n",
    "import torchvision.transforms as T\n",
    "from config import cfg\n",
    "from models import sd_v1_ref_unet\n",
    "from ops.context_diffusion import ContextGaussianDiffusion\n",
    "\n",
    "from PIL import Image, ImageDraw\n",
    "\n",
    "\n",
    "from ldm import  data, models, ops\n",
    "from ldm.models.vae import sd_v1_vae\n",
    "\n",
    "import torchvision.transforms as T\n",
    "from utils import Compose, PadToSquare, get_padding, seed_everything\n",
    "from ldm.models.retinaface import retinaface, crop_face"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# ---- run-time switches ----\n",
    "# The model-loading block further down runs only when DEBUG_VIEW and\n",
    "# SKEP_LOAD are both False.\n",
    "SKIP_LOAD = False\n",
    "DEBUG_VIEW = False\n",
    "# `SKEP_LOAD` is the (misspelled) name the loading guard actually reads.\n",
    "# Alias it to SKIP_LOAD so flipping the correctly spelled flag takes effect.\n",
    "SKEP_LOAD = SKIP_LOAD\n",
    "# When True, the checkpoint at `weight_path` is loaded into the UNet.\n",
    "LOAD_FLAG = True\n",
    "# NOTE(review): DEFAULT_INPUT_IMAGES, MAX_INPUT_IMAGES, SIZE, with_lora and\n",
    "# with_pos_mask are not referenced in the part of the notebook reviewed here.\n",
    "DEFAULT_INPUT_IMAGES = 4\n",
    "MAX_INPUT_IMAGES = 4\n",
    "SIZE = 768\n",
    "with_lora = False\n",
    "# Forwarded to sd_v1_ref_unet(enable_encoder=...).\n",
    "enable_encoder = False\n",
    "with_pos_mask = True\n",
    "\n",
    "# ---- paths / device ----\n",
    "weight_path = f'{package_dir}/cache/flashface.ckpt'\n",
    "\n",
    "gpu = 'cuda'\n",
    "\n",
    "# ---- pre-processing ----\n",
    "# Applied to every reference face before it is encoded.\n",
    "padding_to_square = PadToSquare(224)\n",
    "\n",
    "# Input pipeline of the face detector: pad to a 640 square, then to tensor.\n",
    "retinaface_transforms = T.Compose([PadToSquare(size=640), T.ToTensor()])\n",
    "\n",
    "# `retinaface` is imported as a factory and rebound here to the detector it\n",
    "# builds. The guard keeps this cell re-runnable: without it a second run would\n",
    "# call the detector instance as if it were still the factory.\n",
    "# (Assumes the detector is a torch.nn.Module, as its .eval()/.requires_grad_()\n",
    "# calls suggest; if it is not, the guard is always true and nothing changes.)\n",
    "if not isinstance(retinaface, torch.nn.Module):\n",
    "    retinaface = retinaface(pretrained=True,\n",
    "                            device=gpu).eval().requires_grad_(False)\n",
    "\n",
    "\n",
    "def detect_face(imgs=None):\n",
    "    \"\"\"Detect and crop exactly one face from every reference image.\n",
    "\n",
    "    Args:\n",
    "        imgs: sequence of PIL images. ``None`` or an empty sequence gives\n",
    "            an empty result.\n",
    "\n",
    "    Returns:\n",
    "        List holding the single face crop produced by ``crop_face`` for\n",
    "        each input image, in input order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``crop_face`` does not return exactly one crop for\n",
    "            some image.\n",
    "    \"\"\"\n",
    "    # Nothing to detect; torch.stack below would fail on an empty batch.\n",
    "    if imgs is None or len(imgs) == 0:\n",
    "        return []\n",
    "\n",
    "    pil_imgs = imgs\n",
    "    det_size = 640  # square side produced by `retinaface_transforms`\n",
    "\n",
    "    # detection on the padded batch\n",
    "    batch = torch.stack([retinaface_transforms(u) for u in pil_imgs]).to(gpu)\n",
    "    boxes, kpts = retinaface.detect(batch, min_thr=0.6)\n",
    "\n",
    "    face_imgs = []\n",
    "    for i, pil_img in enumerate(pil_imgs):\n",
    "        # Factor and offsets that took the image into the padded square.\n",
    "        scale = det_size / max(pil_img.size)\n",
    "        left, top, _, _ = get_padding(round(scale * pil_img.width),\n",
    "                                      round(scale * pil_img.height), det_size)\n",
    "\n",
    "        # undo padding\n",
    "        boxes[i][:, [0, 2]] -= left\n",
    "        boxes[i][:, [1, 3]] -= top\n",
    "        kpts[i][:, :, 0] -= left\n",
    "        kpts[i][:, :, 1] -= top\n",
    "\n",
    "        # undo scaling, back to original-image pixel coordinates\n",
    "        boxes[i][:, :4] /= scale\n",
    "        kpts[i][:, :, :2] /= scale\n",
    "\n",
    "        # crop faces; each reference image must contain exactly one\n",
    "        crops = crop_face(pil_img, boxes[i], kpts[i])\n",
    "        if len(crops) != 1:\n",
    "            raise ValueError(\n",
    "                f'Warning: {len(crops)} faces detected in image {i}')\n",
    "\n",
    "        face_imgs += crops\n",
    "\n",
    "    return face_imgs\n",
    "\n",
    "\n",
    "# Build everything `generate` needs: tokenizer, CLIP text model, VAE, the\n",
    "# reference UNet and the diffusion sampler. Skipped when DEBUG_VIEW or\n",
    "# SKEP_LOAD is set, in which case the names below stay undefined.\n",
    "if not DEBUG_VIEW and not SKEP_LOAD:\n",
    "    clip_tokenizer = data.CLIPTokenizer(padding='eos')\n",
    "    # Only the `.textual` sub-module of the model named by cfg.clip_model is kept.\n",
    "    clip = getattr(models, cfg.clip_model)(\n",
    "        pretrained=True).eval().requires_grad_(False).textual.to(gpu)\n",
    "    autoencoder = sd_v1_vae(\n",
    "        pretrained=True).eval().requires_grad_(False).to(gpu)\n",
    "\n",
    "    unet = sd_v1_ref_unet(pretrained=True,\n",
    "                          version='sd-v1-5_nonema',\n",
    "                          enable_encoder=enable_encoder).to(gpu)\n",
    "\n",
    "    # NOTE(review): presumably swaps the UNet's input convolution; it runs\n",
    "    # before the strict state-dict load below - confirm the order is required.\n",
    "    unet.replace_input_conv()\n",
    "    unet = unet.eval().requires_grad_(False).to(gpu)\n",
    "    # `generate` overwrites this entry with 4 on every call.\n",
    "    unet.share_cache['num_pairs'] = cfg.num_pairs\n",
    "\n",
    "\n",
    "    if LOAD_FLAG:\n",
    "        # NOTE(security): torch.load unpickles the file; only load checkpoints\n",
    "        # that come from a trusted source.\n",
    "        model_weight = torch.load(weight_path, map_location=\"cpu\")\n",
    "        # strict=True raises if the checkpoint keys do not match the UNet.\n",
    "        msg = unet.load_state_dict(model_weight, strict=True)\n",
    "        print(msg)\n",
    "\n",
    "    # diffusion: noise schedule and sampler, both configured from cfg\n",
    "    sigmas = ops.noise_schedule(schedule=cfg.schedule,\n",
    "                                n=cfg.num_timesteps,\n",
    "                                beta_min=cfg.scale_min,\n",
    "                                beta_max=cfg.scale_max)\n",
    "    diffusion = ContextGaussianDiffusion(sigmas=sigmas,\n",
    "                                         prediction_type=cfg.prediction_type)\n",
    "    diffusion.num_pairs = cfg.num_pairs\n",
    "    print(\"model initialized\")\n",
    "\n",
    "# Reference-face preprocessing used by `generate`: PIL image -> tensor, then\n",
    "# (x - 0.5) / 0.5 per channel, i.e. ToTensor's [0, 1] range becomes [-1, 1].\n",
    "face_transforms = Compose(\n",
    "    [T.ToTensor(),\n",
    "     T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])\n",
    "\n",
    "\n",
    "def encode_text(m, x):\n",
    "    \"\"\"Run tokens ``x`` through text model ``m`` and return its normalised output.\n",
    "\n",
    "    The tokens are embedded with ``m.token_embedding``, ``m.pos_embedding``\n",
    "    is added, every block of ``m.transformer`` is applied in order and the\n",
    "    result is passed through ``m.norm``.\n",
    "    \"\"\"\n",
    "    hidden = m.token_embedding(x) + m.pos_embedding\n",
    "    for layer in m.transformer:\n",
    "        hidden = layer(hidden)\n",
    "    return m.norm(hidden)\n",
    "\n",
    "\n",
    "def generate(\n",
    "    pos_prompt,\n",
    "    neg_prompt,\n",
    "    steps=35,\n",
    "    face_bbox=[0.3, 0.1, 0.6, 0.4],\n",
    "    lamda_feat=1.2,\n",
    "    face_guidence=3.2,\n",
    "    num_sample=1,\n",
    "    text_control_scale=7.5,\n",
    "    seed=0,\n",
    "    step_to_launch_face_guidence=750,\n",
    "    lamda_feat_before_ref_guidence=0.85,\n",
    "    reference_faces=None,\n",
    "    need_detect=True,\n",
    "    default_pos_prompt='best quality, masterpiece,ultra-detailed, UHD 4K, photographic',\n",
    "    default_neg_prompt='blurry, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face',\n",
    "):\n",
    "    \"\"\"Sample identity-conditioned images from a prompt and reference faces.\n",
    "\n",
    "    Args:\n",
    "        pos_prompt: positive prompt; ``default_pos_prompt`` is appended\n",
    "            unless it is None.\n",
    "        neg_prompt: extra negative prompt or None; joined with\n",
    "            ``default_neg_prompt`` (whichever of the two is not None).\n",
    "        steps: number of sampler steps.\n",
    "        face_bbox: normalised ``[x0, y0, x1, y1]`` of the face region in\n",
    "            the output image. A square mask whose side is the longer box\n",
    "            edge is anchored at ``(x0, y0)``; all zeros gives an empty mask.\n",
    "        lamda_feat: stored as ``similarity`` / ``ori_similarity`` in\n",
    "            ``unet.share_cache``.\n",
    "        face_guidence: stored as ``unet.share_cache['classifier']`` and\n",
    "            ``diffusion.classifier``.\n",
    "        num_sample: number of images to generate.\n",
    "        text_control_scale: passed to the sampler as ``guide_scale``.\n",
    "        seed: RNG seed; ``-1`` picks a random one.\n",
    "        step_to_launch_face_guidence: stored in ``unet.share_cache`` under\n",
    "            the same name.\n",
    "        lamda_feat_before_ref_guidence: stored in ``unet.share_cache``\n",
    "            under the same name.\n",
    "        reference_faces: non-empty list of PIL images of the identity.\n",
    "        need_detect: if True, run ``detect_face`` to crop the references;\n",
    "            if False they are used as given.\n",
    "        default_pos_prompt: suffix appended to the positive prompt.\n",
    "        default_neg_prompt: base negative prompt.\n",
    "\n",
    "    Returns:\n",
    "        List of PIL images, one per sample.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``reference_faces`` is None or empty.\n",
    "    \"\"\"\n",
    "    # Fail early with a clear message instead of a TypeError / empty-stack\n",
    "    # error deep inside detection or encoding.\n",
    "    if reference_faces is None or len(reference_faces) == 0:\n",
    "        raise ValueError('No reference face images were provided')\n",
    "\n",
    "    solver = 'ddim'\n",
    "    num_refs = 4  # value written to unet.share_cache['num_pairs'] below\n",
    "    H = W = 768\n",
    "\n",
    "    if default_pos_prompt is not None:\n",
    "        pos_prompt = pos_prompt + ', ' + default_pos_prompt\n",
    "    # Join whichever negative prompts are present, so that a None default no\n",
    "    # longer raises TypeError when a custom negative prompt is given.\n",
    "    neg_prompt = ', '.join(\n",
    "        p for p in (neg_prompt, default_neg_prompt) if p is not None)\n",
    "\n",
    "    if seed == -1:\n",
    "        seed = random.randint(0, 2147483647)\n",
    "    seed_everything(seed)\n",
    "\n",
    "    print(seed)\n",
    "    print('final pos_prompt: ', pos_prompt)\n",
    "    print('final neg_prompt: ', neg_prompt)\n",
    "\n",
    "    if need_detect:\n",
    "        reference_faces = detect_face(reference_faces)\n",
    "        print(f'detected {len(reference_faces)} faces')\n",
    "        assert len(\n",
    "            reference_faces) > 0, 'No face detected in the reference images'\n",
    "\n",
    "    # Pad to `num_refs` references by re-using random ones. This now also runs\n",
    "    # for need_detect=False, because num_pairs is set to 4 in either case.\n",
    "    # NOTE(review): more than four references are still passed through as-is.\n",
    "    if len(reference_faces) < num_refs:\n",
    "        padded_faces = copy.deepcopy(reference_faces)\n",
    "        while len(padded_faces) < num_refs:\n",
    "            padded_faces.append(random.choice(reference_faces))\n",
    "        reference_faces = padded_faces\n",
    "\n",
    "    # ---- face position mask ----\n",
    "    normalized_bbox = face_bbox\n",
    "    print(normalized_bbox)\n",
    "    x0 = int(normalized_bbox[0] * W)\n",
    "    y0 = int(normalized_bbox[1] * H)\n",
    "    x1 = int(normalized_bbox[2] * W)\n",
    "    y1 = int(normalized_bbox[3] * H)\n",
    "    # Side of the square mask = longer edge of the box. The width is x1 - x0;\n",
    "    # the previous code subtracted y0 from x1, which inflated the mask.\n",
    "    max_size = max(x1 - x0, y1 - y0)\n",
    "    empty_mask = torch.zeros((H, W))\n",
    "    empty_mask[y0:y0 + max_size, x0:x0 + max_size] = 1\n",
    "\n",
    "    # Subsample by 8 to match the (H // 8, W // 8) latent grid sampled below.\n",
    "    empty_mask = empty_mask[::8, ::8].to(gpu)\n",
    "    empty_mask = empty_mask[None].repeat(num_sample, 1, 1)\n",
    "\n",
    "    # ---- reference faces -> tensor batch ----\n",
    "    faces = torch.stack([\n",
    "        face_transforms(padding_to_square(ref_img.convert('RGB')))\n",
    "        for ref_img in reference_faces\n",
    "    ], dim=0).to(gpu)\n",
    "\n",
    "    # ---- text conditioning (positive and negative) ----\n",
    "    c = encode_text(clip, clip_tokenizer([pos_prompt]).to(gpu))\n",
    "    c = c[None].repeat(num_sample, 1, 1, 1).flatten(0, 1)\n",
    "    c = {'context': c}\n",
    "\n",
    "    single_null_context = encode_text(\n",
    "        clip, clip_tokenizer([neg_prompt]).to(gpu)).to(gpu)\n",
    "    nc = {\n",
    "        'context': single_null_context[None].repeat(num_sample, 1, 1,\n",
    "                                                    1).flatten(0, 1)\n",
    "    }\n",
    "\n",
    "    # ---- encode the references and hand everything to the UNet cache ----\n",
    "    ref_z0 = cfg.ae_scale * torch.cat([\n",
    "        autoencoder.sample(u, deterministic=True)\n",
    "        for u in faces.split(cfg.ae_batch_size)\n",
    "    ])\n",
    "    unet.share_cache['num_pairs'] = num_refs\n",
    "    unet.share_cache['ref'] = ref_z0\n",
    "    unet.share_cache['similarity'] = torch.tensor(lamda_feat).to(gpu)\n",
    "    unet.share_cache['ori_similarity'] = torch.tensor(lamda_feat).to(gpu)\n",
    "    unet.share_cache['lamda_feat_before_ref_guidence'] = torch.tensor(\n",
    "        lamda_feat_before_ref_guidence).to(gpu)\n",
    "    unet.share_cache['ref_context'] = single_null_context.repeat(\n",
    "        len(ref_z0), 1, 1)\n",
    "    unet.share_cache['masks'] = empty_mask\n",
    "    unet.share_cache['classifier'] = face_guidence\n",
    "    unet.share_cache['step_to_launch_face_guidence'] = step_to_launch_face_guidence\n",
    "\n",
    "    diffusion.classifier = face_guidence\n",
    "\n",
    "    # ---- sample ----\n",
    "    with amp.autocast(dtype=cfg.flash_dtype), torch.no_grad():\n",
    "        z0 = diffusion.sample(solver=solver,\n",
    "                              noise=torch.empty(num_sample, 4,\n",
    "                                                H // 8,\n",
    "                                                W // 8,\n",
    "                                                device=gpu).normal_(),\n",
    "                              model=unet,\n",
    "                              model_kwargs=[c, nc],\n",
    "                              steps=steps,\n",
    "                              guide_scale=text_control_scale,\n",
    "                              guide_rescale=0.5,\n",
    "                              show_progress=True,\n",
    "                              discretization=cfg.discretization)\n",
    "\n",
    "    imgs = autoencoder.decode(z0 / cfg.ae_scale)\n",
    "    del unet.share_cache['ori_similarity']\n",
    "\n",
    "    # ---- [-1, 1] float NCHW -> uint8 NHWC -> PIL ----\n",
    "    imgs = (imgs.permute(0, 2, 3, 1) * 127.5 + 127.5).cpu().numpy().clip(\n",
    "        0, 255).astype(np.uint8)\n",
    "\n",
    "    return [Image.fromarray(img) for img in imgs]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recommended hyper-parameters to obtain stable ID fidelity\n",
    "face_imgs = [Image.open(f\"{package_dir}/examples/age/{i+1}.png\").convert(\"RGB\") for i in range(3)]\n",
    "need_detect = True\n",
    "pos_prompt = 'A beautiful young asian woman, in a traditional chinese outfit, long hair, complete with a classic hairpin, on the street , white skin, soft light'\n",
    "num_samples = 4\n",
    "# centred face position, normalised [x0, y0, x1, y1]\n",
    "face_bbox = [0.3, 0.2, 0.6, 0.5]\n",
    "# larger values of these three parameters give more fidelity but less diversity\n",
    "lamda_feat = 1.2\n",
    "face_guidence = 3.2\n",
    "step_to_launch_face_guidence = 750\n",
    "\n",
    "steps = 25\n",
    "default_text_control_scale = 7.5\n",
    "\n",
    "default_seed = 0\n",
    "\n",
    "\n",
    "imgs = generate(\n",
    "    pos_prompt=pos_prompt,\n",
    "    neg_prompt=None,  # default negative prompt only\n",
    "    steps=steps,\n",
    "    face_bbox=face_bbox,\n",
    "    lamda_feat=lamda_feat,\n",
    "    face_guidence=face_guidence,\n",
    "    num_sample=num_samples,\n",
    "    text_control_scale=default_text_control_scale,\n",
    "    seed=default_seed,\n",
    "    step_to_launch_face_guidence=step_to_launch_face_guidence,\n",
    "    reference_faces=face_imgs,\n",
    "    need_detect=need_detect,\n",
    ")\n",
    "\n",
    "\n",
    "# Contact sheet: tile 0 holds the reference thumbnails, tiles 1..N the results.\n",
    "img_size = imgs[0].size\n",
    "num_imgs = len(imgs)\n",
    "save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))\n",
    "for i, img in enumerate(imgs):\n",
    "    save_img.paste(img, ((i + 1) * img_size[0], 0))\n",
    "\n",
    "# Lay the reference images out as a 2x2 grid of half-size cells in tile 0.\n",
    "resize_w = img_size[0] // 2\n",
    "resize_h = img_size[1] // 2\n",
    "\n",
    "# `idx`, not `id`: the loop variable is a notebook global and would otherwise\n",
    "# shadow the built-in id() for every later cell.\n",
    "for idx, ref_img in enumerate(face_imgs):\n",
    "    # scale to fit one grid cell while keeping the aspect ratio\n",
    "    ratio = min(resize_w / ref_img.size[0], resize_h / ref_img.size[1])\n",
    "    thumb = ref_img.resize(\n",
    "        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))\n",
    "    row, col = divmod(idx, 2)\n",
    "    save_img.paste(thumb, (col * resize_w, row * resize_h))\n",
    "\n",
    "display(save_img)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recommended hyper-parameters to obtain stable ID fidelity\n",
    "face_imgs = [Image.open(f\"{package_dir}/examples/age/{i+1}.png\").convert(\"RGB\") for i in range(3)]\n",
    "need_detect = True\n",
    "pos_prompt = 'A very  old woman with short wavy hair'\n",
    "num_samples = 4\n",
    "# centred face position, normalised [x0, y0, x1, y1]\n",
    "face_bbox = [0.3, 0.1, 0.6, 0.4]\n",
    "# larger values of these three parameters give more fidelity but less diversity\n",
    "lamda_feat = 1\n",
    "face_guidence = 2.5\n",
    "step_to_launch_face_guidence = 750\n",
    "\n",
    "steps = 25\n",
    "default_text_control_scale = 8.5\n",
    "\n",
    "default_seed = 0\n",
    "\n",
    "\n",
    "imgs = generate(\n",
    "    pos_prompt=pos_prompt,\n",
    "    neg_prompt=None,  # default negative prompt only\n",
    "    steps=steps,\n",
    "    face_bbox=face_bbox,\n",
    "    lamda_feat=lamda_feat,\n",
    "    face_guidence=face_guidence,\n",
    "    num_sample=num_samples,\n",
    "    text_control_scale=default_text_control_scale,\n",
    "    seed=default_seed,\n",
    "    step_to_launch_face_guidence=step_to_launch_face_guidence,\n",
    "    reference_faces=face_imgs,\n",
    "    need_detect=need_detect,\n",
    ")\n",
    "\n",
    "\n",
    "# Contact sheet: tile 0 holds the reference thumbnails, tiles 1..N the results.\n",
    "img_size = imgs[0].size\n",
    "num_imgs = len(imgs)\n",
    "save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))\n",
    "for i, img in enumerate(imgs):\n",
    "    save_img.paste(img, ((i + 1) * img_size[0], 0))\n",
    "\n",
    "# Lay the reference images out as a 2x2 grid of half-size cells in tile 0.\n",
    "resize_w = img_size[0] // 2\n",
    "resize_h = img_size[1] // 2\n",
    "\n",
    "# `idx`, not `id`: the loop variable is a notebook global and would otherwise\n",
    "# shadow the built-in id() for every later cell.\n",
    "for idx, ref_img in enumerate(face_imgs):\n",
    "    # scale to fit one grid cell while keeping the aspect ratio\n",
    "    ratio = min(resize_w / ref_img.size[0], resize_h / ref_img.size[1])\n",
    "    thumb = ref_img.resize(\n",
    "        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))\n",
    "    row, col = divmod(idx, 2)\n",
    "    save_img.paste(thumb, (col * resize_w, row * resize_h))\n",
    "\n",
    "display(save_img)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "face_imgs = [\n",
    "    Image.open(f\"{package_dir}/examples/age/{i+1}.png\").convert(\"RGB\") for i in range(3)\n",
    "]\n",
    "need_detect = True\n",
    "\n",
    "pos_prompt = \"\"\"The cute, beautiful baby girl with medium length brown hair and  pink bow, in the studio \"\"\"\n",
    "# no extra negative prompt: only generate()'s default one is used\n",
    "neg_prompt = None\n",
    "num_samples = 4\n",
    "# face position, normalised [x0, y0, x1, y1]\n",
    "face_bbox = [0.3, 0.2, 0.6, 0.6]\n",
    "\n",
    "\n",
    "# larger values of these three parameters give more fidelity but less diversity\n",
    "lamda_feat = 1.2\n",
    "face_guidence = 2\n",
    "step_to_launch_face_guidence = 700\n",
    "\n",
    "steps = 50\n",
    "default_text_control_scale = 7.5\n",
    "\n",
    "default_seed = 0\n",
    "\n",
    "\n",
    "imgs = generate(\n",
    "    pos_prompt=pos_prompt,\n",
    "    neg_prompt=neg_prompt,\n",
    "    steps=steps,\n",
    "    face_bbox=face_bbox,\n",
    "    lamda_feat=lamda_feat,\n",
    "    face_guidence=face_guidence,\n",
    "    num_sample=num_samples,\n",
    "    text_control_scale=default_text_control_scale,\n",
    "    seed=default_seed,\n",
    "    step_to_launch_face_guidence=step_to_launch_face_guidence,\n",
    "    reference_faces=face_imgs,\n",
    "    need_detect=need_detect,\n",
    ")\n",
    "\n",
    "\n",
    "# Contact sheet: tile 0 holds the reference thumbnails, tiles 1..N the results.\n",
    "img_size = imgs[0].size\n",
    "num_imgs = len(imgs)\n",
    "save_img = Image.new(\"RGB\", (img_size[0] * (num_imgs + 1), img_size[1]))\n",
    "for i, img in enumerate(imgs):\n",
    "    save_img.paste(img, ((i + 1) * img_size[0], 0))\n",
    "\n",
    "# Lay the reference images out as a 2x2 grid of half-size cells in tile 0.\n",
    "resize_w = img_size[0] // 2\n",
    "resize_h = img_size[1] // 2\n",
    "\n",
    "# `idx`, not `id`: the loop variable is a notebook global and would otherwise\n",
    "# shadow the built-in id() for every later cell.\n",
    "for idx, ref_img in enumerate(face_imgs):\n",
    "    # scale to fit one grid cell while keeping the aspect ratio\n",
    "    ratio = min(resize_w / ref_img.size[0], resize_h / ref_img.size[1])\n",
    "    thumb = ref_img.resize(\n",
    "        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio))\n",
    "    )\n",
    "    row, col = divmod(idx, 2)\n",
    "    save_img.paste(thumb, (col * resize_w, row * resize_h))\n",
    "\n",
    "display(save_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "face_imgs = [Image.open(f\"{package_dir}/examples/avatar.png\").convert(\"RGB\")]\n",
    "need_detect = True\n",
    "pos_prompt = \"A handsome young man with long brown hair is sitting in the desert\"\n",
    "num_samples = 2\n",
    "# no face position: an all-zero box produces an empty mask\n",
    "face_bbox = [0., 0., 0., 0.]\n",
    "# larger values of these three parameters give more fidelity but less diversity\n",
    "lamda_feat = 0.9\n",
    "face_guidence = 2.5\n",
    "step_to_launch_face_guidence = 700\n",
    "\n",
    "steps = 50\n",
    "default_text_control_scale = 7.5\n",
    "\n",
    "default_seed = 0\n",
    "\n",
    "\n",
    "imgs = generate(\n",
    "    pos_prompt=pos_prompt,\n",
    "    neg_prompt=None,  # default negative prompt only\n",
    "    steps=steps,\n",
    "    face_bbox=face_bbox,\n",
    "    lamda_feat=lamda_feat,\n",
    "    face_guidence=face_guidence,\n",
    "    num_sample=num_samples,\n",
    "    text_control_scale=default_text_control_scale,\n",
    "    seed=default_seed,\n",
    "    step_to_launch_face_guidence=step_to_launch_face_guidence,\n",
    "    reference_faces=face_imgs,\n",
    "    need_detect=need_detect,\n",
    ")\n",
    "\n",
    "\n",
    "# Contact sheet: tile 0 holds the reference thumbnails, tiles 1..N the results.\n",
    "img_size = imgs[0].size\n",
    "num_imgs = len(imgs)\n",
    "save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))\n",
    "for i, img in enumerate(imgs):\n",
    "    save_img.paste(img, ((i + 1) * img_size[0], 0))\n",
    "\n",
    "# Lay the reference images out as a 2x2 grid of half-size cells in tile 0.\n",
    "resize_w = img_size[0] // 2\n",
    "resize_h = img_size[1] // 2\n",
    "\n",
    "# `idx`, not `id`: the loop variable is a notebook global and would otherwise\n",
    "# shadow the built-in id() for every later cell.\n",
    "for idx, ref_img in enumerate(face_imgs):\n",
    "    # scale to fit one grid cell while keeping the aspect ratio\n",
    "    ratio = min(resize_w / ref_img.size[0], resize_h / ref_img.size[1])\n",
    "    thumb = ref_img.resize(\n",
    "        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))\n",
    "    row, col = divmod(idx, 2)\n",
    "    save_img.paste(thumb, (col * resize_w, row * resize_h))\n",
    "\n",
    "display(save_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "face_imgs = [Image.open(f\"{package_dir}/examples/snow_white.png\").convert(\"RGB\")]\n",
    "need_detect = True\n",
    "pos_prompt = \"Full body photo of a beautiful young women sitting in the office, medium length wavy hair, wearinig red bow hairpin on the top of head\"\n",
    "num_samples = 2\n",
    "# no face position: an all-zero box produces an empty mask\n",
    "face_bbox = [0., 0., 0., 0.]\n",
    "# larger values of these three parameters give more fidelity but less diversity\n",
    "lamda_feat = 1\n",
    "face_guidence = 2\n",
    "step_to_launch_face_guidence = 600\n",
    "\n",
    "steps = 50\n",
    "default_text_control_scale = 7.5\n",
    "\n",
    "default_seed = 0\n",
    "\n",
    "\n",
    "imgs = generate(\n",
    "    pos_prompt=pos_prompt,\n",
    "    neg_prompt=None,  # default negative prompt only\n",
    "    steps=steps,\n",
    "    face_bbox=face_bbox,\n",
    "    lamda_feat=lamda_feat,\n",
    "    face_guidence=face_guidence,\n",
    "    num_sample=num_samples,\n",
    "    text_control_scale=default_text_control_scale,\n",
    "    seed=default_seed,\n",
    "    step_to_launch_face_guidence=step_to_launch_face_guidence,\n",
    "    reference_faces=face_imgs,\n",
    "    need_detect=need_detect,\n",
    ")\n",
    "\n",
    "\n",
    "# Contact sheet: tile 0 holds the reference thumbnails, tiles 1..N the results.\n",
    "img_size = imgs[0].size\n",
    "num_imgs = len(imgs)\n",
    "save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))\n",
    "for i, img in enumerate(imgs):\n",
    "    save_img.paste(img, ((i + 1) * img_size[0], 0))\n",
    "\n",
    "# Lay the reference images out as a 2x2 grid of half-size cells in tile 0.\n",
    "resize_w = img_size[0] // 2\n",
    "resize_h = img_size[1] // 2\n",
    "\n",
    "# `idx`, not `id`: the loop variable is a notebook global and would otherwise\n",
    "# shadow the built-in id() for every later cell.\n",
    "for idx, ref_img in enumerate(face_imgs):\n",
    "    # scale to fit one grid cell while keeping the aspect ratio\n",
    "    ratio = min(resize_w / ref_img.size[0], resize_h / ref_img.size[1])\n",
    "    thumb = ref_img.resize(\n",
    "        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))\n",
    "    row, col = divmod(idx, 2)\n",
    "    save_img.paste(thumb, (col * resize_w, row * resize_h))\n",
    "\n",
    "display(save_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# eren_jaeger.png\n",
    "\n",
    "face_imgs = [Image.open(f\"{package_dir}/examples/eren_jaeger.png\").convert(\"RGB\")]\n",
    "need_detect = True\n",
    "pos_prompt = \"A handsome, attractive, sleek young man sitting on the beach, wearing black long trench coat, man bun hair,  heavily clouded, sunset, sea in the background\"\n",
    "# remove beard\n",
    "neg_prompt = \"beard\"\n",
    "# no face position: an all-zero box produces an empty mask\n",
    "face_bbox = [0., 0., 0., 0.]\n",
    "# larger values of these three parameters give more fidelity but less diversity\n",
    "lamda_feat = 0.9\n",
    "face_guidence = 2\n",
    "step_to_launch_face_guidence = 600\n",
    "num_samples = 2\n",
    "steps = 50\n",
    "default_text_control_scale = 7.5\n",
    "\n",
    "default_seed = 0\n",
    "\n",
    "\n",
    "imgs = generate(\n",
    "    pos_prompt=pos_prompt,\n",
    "    neg_prompt=neg_prompt,\n",
    "    steps=steps,\n",
    "    face_bbox=face_bbox,\n",
    "    lamda_feat=lamda_feat,\n",
    "    face_guidence=face_guidence,\n",
    "    num_sample=num_samples,\n",
    "    text_control_scale=default_text_control_scale,\n",
    "    seed=default_seed,\n",
    "    step_to_launch_face_guidence=step_to_launch_face_guidence,\n",
    "    reference_faces=face_imgs,\n",
    "    need_detect=need_detect,\n",
    ")\n",
    "\n",
    "\n",
    "# Contact sheet: tile 0 holds the reference thumbnails, tiles 1..N the results.\n",
    "img_size = imgs[0].size\n",
    "num_imgs = len(imgs)\n",
    "save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))\n",
    "for i, img in enumerate(imgs):\n",
    "    save_img.paste(img, ((i + 1) * img_size[0], 0))\n",
    "\n",
    "# Lay the reference images out as a 2x2 grid of half-size cells in tile 0.\n",
    "resize_w = img_size[0] // 2\n",
    "resize_h = img_size[1] // 2\n",
    "\n",
    "# `idx`, not `id`: the loop variable is a notebook global and would otherwise\n",
    "# shadow the built-in id() for every later cell.\n",
    "for idx, ref_img in enumerate(face_imgs):\n",
    "    # scale to fit one grid cell while keeping the aspect ratio\n",
    "    ratio = min(resize_w / ref_img.size[0], resize_h / ref_img.size[1])\n",
    "    thumb = ref_img.resize(\n",
    "        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))\n",
    "    row, col = divmod(idx, 2)\n",
    "    save_img.paste(thumb, (col * resize_w, row * resize_h))\n",
    "\n",
    "display(save_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ordinary people\n",
    "\n",
    "face_imgs = [Image.open(f\"{package_dir}/examples/man_face/{i+1}.png\").convert(\"RGB\") for i in range(4)]\n",
    "need_detect = True\n",
    "pos_prompt = \"An handsome young man, with cowboy hat, long hair, full body, standing in the forest, sunset\"\n",
    "# remove beard\n",
    "neg_prompt = \"beard\"\n",
    "# no face position: an all-zero box produces an empty mask\n",
    "face_bbox = [0., 0., 0., 0.]\n",
    "# larger values of these three parameters give more fidelity but less diversity\n",
    "lamda_feat = 0.85\n",
    "face_guidence = 2\n",
    "step_to_launch_face_guidence = 600\n",
    "num_samples = 2\n",
    "steps = 50\n",
    "default_text_control_scale = 7.5\n",
    "\n",
    "default_seed = 0\n",
    "\n",
    "\n",
    "imgs = generate(\n",
    "    pos_prompt=pos_prompt,\n",
    "    neg_prompt=neg_prompt,\n",
    "    steps=steps,\n",
    "    face_bbox=face_bbox,\n",
    "    lamda_feat=lamda_feat,\n",
    "    face_guidence=face_guidence,\n",
    "    num_sample=num_samples,\n",
    "    text_control_scale=default_text_control_scale,\n",
    "    seed=default_seed,\n",
    "    step_to_launch_face_guidence=step_to_launch_face_guidence,\n",
    "    reference_faces=face_imgs,\n",
    "    need_detect=need_detect,\n",
    ")\n",
    "\n",
    "\n",
    "# Contact sheet: tile 0 holds the reference thumbnails, tiles 1..N the results.\n",
    "img_size = imgs[0].size\n",
    "num_imgs = len(imgs)\n",
    "save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))\n",
    "for i, img in enumerate(imgs):\n",
    "    save_img.paste(img, ((i + 1) * img_size[0], 0))\n",
    "\n",
    "# Lay the reference images out as a 2x2 grid of half-size cells in tile 0.\n",
    "resize_w = img_size[0] // 2\n",
    "resize_h = img_size[1] // 2\n",
    "\n",
    "# `idx`, not `id`: the loop variable is a notebook global and would otherwise\n",
    "# shadow the built-in id() for every later cell.\n",
    "for idx, ref_img in enumerate(face_imgs):\n",
    "    # scale to fit one grid cell while keeping the aspect ratio\n",
    "    ratio = min(resize_w / ref_img.size[0], resize_h / ref_img.size[1])\n",
    "    thumb = ref_img.resize(\n",
    "        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))\n",
    "    row, col = divmod(idx, 2)\n",
    "    save_img.paste(thumb, (col * resize_w, row * resize_h))\n",
    "\n",
    "display(save_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ordinary people\n",
    "\n",
    "face_imgs = [Image.open(f\"{package_dir}/examples/woman_face/{i+1}.png\").convert(\"RGB\") for i in range(4)]\n",
    "need_detect = True\n",
    "pos_prompt = 'A beautiful young woman with short curly hair in the garden holding a flower'\n",
    "# no extra negative prompt: only generate()'s default one is used\n",
    "neg_prompt = None\n",
    "# no face position: an all-zero box produces an empty mask\n",
    "face_bbox = [0., 0., 0., 0.]\n",
    "# larger values of these three parameters give more fidelity but less diversity\n",
    "lamda_feat = 1\n",
    "face_guidence = 2.3\n",
    "step_to_launch_face_guidence = 600\n",
    "num_samples = 2\n",
    "steps = 50\n",
    "default_text_control_scale = 7.5\n",
    "\n",
    "default_seed = 0\n",
    "\n",
    "\n",
    "imgs = generate(\n",
    "    pos_prompt=pos_prompt,\n",
    "    neg_prompt=neg_prompt,\n",
    "    steps=steps,\n",
    "    face_bbox=face_bbox,\n",
    "    lamda_feat=lamda_feat,\n",
    "    face_guidence=face_guidence,\n",
    "    num_sample=num_samples,\n",
    "    text_control_scale=default_text_control_scale,\n",
    "    seed=default_seed,\n",
    "    step_to_launch_face_guidence=step_to_launch_face_guidence,\n",
    "    reference_faces=face_imgs,\n",
    "    need_detect=need_detect,\n",
    ")\n",
    "\n",
    "\n",
    "# Contact sheet: tile 0 holds the reference thumbnails, tiles 1..N the results.\n",
    "img_size = imgs[0].size\n",
    "num_imgs = len(imgs)\n",
    "save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))\n",
    "for i, img in enumerate(imgs):\n",
    "    save_img.paste(img, ((i + 1) * img_size[0], 0))\n",
    "\n",
    "# Lay the reference images out as a 2x2 grid of half-size cells in tile 0.\n",
    "resize_w = img_size[0] // 2\n",
    "resize_h = img_size[1] // 2\n",
    "\n",
    "# `idx`, not `id`: the loop variable is a notebook global and would otherwise\n",
    "# shadow the built-in id() for every later cell.\n",
    "for idx, ref_img in enumerate(face_imgs):\n",
    "    # scale to fit one grid cell while keeping the aspect ratio\n",
    "    ratio = min(resize_w / ref_img.size[0], resize_h / ref_img.size[1])\n",
    "    thumb = ref_img.resize(\n",
    "        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))\n",
    "    row, col = divmod(idx, 2)\n",
    "    save_img.paste(thumb, (col * resize_w, row * resize_h))\n",
    "\n",
    "display(save_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example: details.\n",
    "# Generates samples from four reference faces, then displays the references\n",
    "# (2x2 grid in the first tile) next to the generated images.\n",
    "\n",
    "\n",
    "face_imgs = [Image.open(f\"{package_dir}/examples/details_face/{i+1}.jpeg\").convert(\"RGB\") for i in range(4)]\n",
    "need_detect = True\n",
    "pos_prompt = 'A beautiful young woman stands in the street,  wearing earing and white skirt and  hat, thin body, sunny day'\n",
    "# negative prompt: steer the result away from bangs\n",
    "neg_prompt = 'Bangs'\n",
    "# face placed toward the top-left corner\n",
    "# (presumably normalized [x1, y1, x2, y2] -- TODO confirm against generate)\n",
    "face_bbox = [0.1, 0.1, 0.5, 0.5]\n",
    "# larger values of these three parameters give more fidelity but less diversity\n",
    "\n",
    "lamda_feat = 1.3\n",
    "face_guidence = 3.2\n",
    "step_to_launch_face_guidence = 800\n",
    "\n",
    "steps = 50\n",
    "default_text_control_scale = 8\n",
    "\n",
    "default_seed = 0\n",
    "num_samples = 2\n",
    "\n",
    "imgs = generate(\n",
    "    pos_prompt=pos_prompt,\n",
    "    neg_prompt=neg_prompt,\n",
    "    steps=steps,\n",
    "    face_bbox=face_bbox,\n",
    "    lamda_feat=lamda_feat,\n",
    "    face_guidence=face_guidence,\n",
    "    num_sample=num_samples,\n",
    "    text_control_scale=default_text_control_scale,\n",
    "    seed=default_seed,\n",
    "    step_to_launch_face_guidence=step_to_launch_face_guidence,\n",
    "    reference_faces=face_imgs,\n",
    "    need_detect=need_detect,\n",
    ")\n",
    "\n",
    "\n",
    "# Canvas layout: one tile (same size as a generated image) reserved for the\n",
    "# references, followed by one tile per generated image.\n",
    "img_size = imgs[0].size\n",
    "num_imgs = len(imgs)\n",
    "save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))\n",
    "for i, img in enumerate(imgs):\n",
    "    save_img.paste(img, ((i + 1) * img_size[0], 0))\n",
    "\n",
    "# The first tile holds the reference faces in a 2x2 grid of half-size cells.\n",
    "resize_w = img_size[0] // 2\n",
    "resize_h = img_size[1] // 2\n",
    "\n",
    "# Only four cells exist, so at most four references are drawn; any extra\n",
    "# reference would otherwise be pasted over the generated images.\n",
    "for ref_idx, ref_img in enumerate(face_imgs[:4]):\n",
    "    # scale the reference to fit inside one cell while keeping its aspect ratio\n",
    "    w_ratio = resize_w / ref_img.size[0]\n",
    "    h_ratio = resize_h / ref_img.size[1]\n",
    "    ratio = min(w_ratio, h_ratio)\n",
    "    # clamp to 1px so an extreme aspect ratio cannot yield a zero-sized target\n",
    "    ref_img = ref_img.resize(\n",
    "        (max(1, int(ref_img.size[0] * ratio)),\n",
    "         max(1, int(ref_img.size[1] * ratio))))\n",
    "\n",
    "    # row-major placement: indices 0,1 on the top row, 2,3 on the bottom row\n",
    "    grid_row, grid_col = divmod(ref_idx, 2)\n",
    "    save_img.paste(ref_img, (grid_col * resize_w, grid_row * resize_h))\n",
    "\n",
    "display(save_img)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ranni",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
